{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys, os\n",
    "if 'google.colab' in sys.modules and not os.path.exists('.setup_complete'):\n",
    "    !wget -q https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/setup_colab.sh -O- | bash\n",
    "\n",
    "    !wget -q https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/week09_policy_II/td3_and_sac/logger.py\n",
    "\n",
    "    !pip -q install gymnasium[mujoco]\n",
    "    !pip -q install tensorboardX\n",
    "\n",
    "    !touch .setup_complete\n",
    "\n",
    "# This code creates a virtual display to draw game images on.\n",
    "# It will have no effect if your machine has a monitor.\n",
    "if not os.environ.get(\"DISPLAY\"):\n",
    "    !bash ../xvfb start\n",
    "    os.environ['DISPLAY'] = ':1'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Continuous Control\n",
    "\n",
    "\n",
    "In this notebook you will solve a continuous control environment using either [Twin Delayed DDPG (TD3)](https://arxiv.org/pdf/1802.09477.pdf) or [Soft Actor-Critic (SAC)](https://arxiv.org/pdf/1801.01290.pdf). Both are off-policy algorithms and are common go-to choices for continuous control tasks.\n",
    "\n",
    "**Select one** of these two algorithms (TD3 or SAC) to implement. Both extend the basic [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/abs/1509.02971) algorithm, which is roughly \"DQN with another neural net approximating the greedy policy\". They differ from DDPG, and from each other, in a set of stabilization tricks:\n",
    "* TD3 trains a deterministic policy, while SAC uses a *stochastic policy*. This means that in SAC you can handle the exploration-exploitation trade-off by simply sampling from the policy, while in TD3 you have to add noise to your actions.\n",
    "* TD3 stabilizes targets by adding *clipped noise* to actions, which reduces overestimation somewhat. SAC instead switches to the formalism of Maximum Entropy RL and adds an *entropy bonus* to the value function.\n",
    "\n",
    "Both algorithms also use a *twin trick*: train two critics and build pessimistic targets by taking the minimum of their two estimates. The standard trick with target networks is also necessary. We will go through all these tricks step by step.\n",
    "\n",
    "SAC is arguably a more elegant scheme than TD3, but it requires a bit more code to implement. More detailed descriptions of the algorithms can be found in the Spinning Up documentation:\n",
    "* on [DDPG](https://spinningup.openai.com/en/latest/algorithms/ddpg.html)\n",
    "* on [TD3](https://spinningup.openai.com/en/latest/algorithms/td3.html)\n",
    "* on [SAC](https://spinningup.openai.com/en/latest/algorithms/sac.html)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gymnasium as gym\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, we will create an instance of the environment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-16T18:41:00.003174Z",
     "start_time": "2020-09-16T18:40:59.921640Z"
    }
   },
   "outputs": [],
   "source": [
    "env = gym.make(\"Ant-v4\", render_mode=\"rgb_array\")\n",
    "\n",
    "# we want to look inside\n",
    "env.reset()\n",
    "\n",
    "# examples of states and actions\n",
    "print(\"observation space: \", env.observation_space,\n",
    "      \"\\nobservations:\", env.reset()[0])\n",
    "print(\"action space: \", env.action_space,\n",
    "      \"\\naction_sample: \", env.action_space.sample())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "plt.imshow(env.render())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's run a random policy and see how it behaves."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class RandomActor():\n",
    "    def get_action(self, states):\n",
    "        assert len(states.shape) == 1, \"can't work with batches\"\n",
    "        return env.action_space.sample()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "s, _ = env.reset()\n",
    "rewards_per_step = []\n",
    "actor = RandomActor()\n",
    "\n",
    "for i in range(10000):\n",
    "    a = actor.get_action(s)\n",
    "    s, r, terminated, truncated, _ = env.step(a)\n",
    "\n",
    "    rewards_per_step.append(r)\n",
    "\n",
    "    if terminated or truncated:\n",
    "        s, _ = env.reset()\n",
    "        print(\"done: \", i)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Most episodes last 1000 steps, after which they are truncated by the time limit. Occasionally an episode terminates earlier, when the simulation detects that we have crashed our ant. An important property of continuous control tasks like this one is that we receive a non-trivial reward signal at every step:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rewards_per_step[100:110]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This dense signal will guide our optimization. It also partially explains why off-policy algorithms tend to be more sample-efficient than on-policy algorithms like PPO on such tasks: 1-step targets are already quite informative."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "env.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will add a single wrapper to our environment. It simply writes summaries, mainly the total reward per episode."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from logger import TensorboardSummaries as Summaries\n",
    "\n",
    "env = gym.make(\"Ant-v4\", render_mode=\"rgb_array\")\n",
    "env = Summaries(env, \"MyFirstWalkingAnt\");\n",
    "\n",
    "state_dim = env.observation_space.shape[0]  # dimension of state space (27 numbers)\n",
    "action_dim = env.action_space.shape[0]      # dimension of action space (8 numbers)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Models"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's start with the *critic* model. On the one hand, it serves as an approximation of $Q^*(s, a)$; on the other hand, it evaluates the current actor $\\pi$ and can be viewed as $Q^{\\pi}(s, a)$. The critic takes both state $s$ and action $a$ as input and outputs a scalar value. The recommended architecture is a 3-layer MLP.\n",
    "\n",
    "**Danger:** when a model has a scalar output, it is a good rule to squeeze it to shape `[batch_size]` to avoid unexpected broadcasting. For example, subtracting a `[batch_size]` target from a `[batch_size, 1]` prediction silently produces a `[batch_size, batch_size]` tensor."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "class Critic(nn.Module):\n",
    "    def __init__(self, state_dim, action_dim):\n",
    "        super().__init__()        \n",
    "\n",
    "        <YOUR CODE>\n",
    "\n",
    "    def get_qvalues(self, states, actions):\n",
    "        '''\n",
    "        input:\n",
    "            states - tensor, (batch_size x features)\n",
    "            actions - tensor, (batch_size x actions_dim)\n",
    "        output:\n",
    "            qvalues - tensor, critic estimation, (batch_size)\n",
    "        '''\n",
    "        qvalues = <YOUR CODE>\n",
    "\n",
    "        assert len(qvalues.shape) == 1 and qvalues.shape[0] == states.shape[0]\n",
    "        \n",
    "        return qvalues"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, let's define a policy, or actor, $\\pi$. Use an architecture similar to the critic (3-layer MLP). The output depends on the algorithm:\n",
    "\n",
    "For **TD3**, model a *deterministic policy*. You should output `action_dim` numbers in the range $[-1, 1]$. Unfortunately, deterministic policies lead to problems with stability and exploration, so the policy needs three \"modes\" of operation:\n",
    "* The first, greedy mode is a plain forward pass through the network. It is used to train the actor.\n",
    "* The second, exploration mode adds noise (e.g. Gaussian) to the actions to collect more diverse data.\n",
    "* The third, \"clipped noise\" mode is used to build targets for the critic. Here we want to perturb the actor output, but not too much, so we add *clipped noise*:\n",
    "$$\\pi_{\\theta}(s) + \\varepsilon, \\quad \\varepsilon = \\operatorname{clip}(\\epsilon, -0.5, 0.5), \\epsilon \\sim \\mathcal{N}(0, \\sigma^2 I)$$"
   ]
  },
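  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a compact summary, the three TD3 modes can be written side by side. This is only a sketch: the noise scales $\\sigma_{\\text{explore}} = 0.1$ and $\\sigma_{\\text{target}} = 0.2$ and the clipping bound $c = 0.5$ are taken from the default arguments of the `TD3_Actor` template (`std_noise`, `clip_eta`), and the outer clipping to $[-1, 1]$ is assumed from the range checks in that template.\n",
    "\n",
    "$$\\begin{aligned}\n",
    "\\text{greedy:} \\quad & a = \\pi_{\\theta}(s) \\\\\n",
    "\\text{exploration:} \\quad & a = \\operatorname{clip}\\left(\\pi_{\\theta}(s) + \\epsilon, -1, 1\\right), \\quad \\epsilon \\sim \\mathcal{N}(0, \\sigma_{\\text{explore}}^2 I) \\\\\n",
    "\\text{clipped noise:} \\quad & a = \\operatorname{clip}\\left(\\pi_{\\theta}(s) + \\operatorname{clip}(\\epsilon, -c, c), -1, 1\\right), \\quad \\epsilon \\sim \\mathcal{N}(0, \\sigma_{\\text{target}}^2 I)\n",
    "\\end{aligned}$$\n",
    "\n",
    "Only the greedy mode needs gradients with respect to $\\theta$; the other two are used to produce data and targets."
   ]
  },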
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-16T18:41:06.246418Z",
     "start_time": "2020-09-16T18:41:05.841255Z"
    }
   },
   "outputs": [],
   "source": [
    "# template for TD3; template for SAC is below\n",
    "class TD3_Actor(nn.Module):\n",
    "    def __init__(self, state_dim, action_dim):\n",
    "        super().__init__()        \n",
    "\n",
    "        <YOUR CODE>\n",
    "\n",
    "    def get_action(self, states, std_noise=0.1):\n",
    "        '''\n",
    "        Used to collect data by interacting with environment,\n",
    "        so you have to add some noise to actions.\n",
    "        input:\n",
    "            states - numpy, (batch_size x features)\n",
    "        output:\n",
    "            actions - numpy, (batch_size x actions_dim)\n",
    "        '''\n",
    "        # no gradient computation is required here since we will use this only for interaction\n",
    "        with torch.no_grad():\n",
    "            actions = <YOUR CODE>\n",
    "            \n",
    "            assert isinstance(actions, (list,np.ndarray)), \"convert actions to numpy to send into env\"\n",
    "            assert actions.max() <= 1. and actions.min() >= -1, \"actions must be in the range [-1, 1]\"\n",
    "            return actions\n",
    "    \n",
    "    def get_best_action(self, states):\n",
    "        '''\n",
    "        Will be used to optimize the actor. Requires actions that are differentiable w.r.t. the actor parameters.\n",
    "        input:\n",
    "            states - PyTorch tensor, (batch_size x features)\n",
    "        output:\n",
    "            actions - PyTorch tensor, (batch_size x actions_dim)\n",
    "        '''\n",
    "        actions = <YOUR CODE>\n",
    "        \n",
    "        assert actions.requires_grad, \"you must be able to compute gradients through actions\"\n",
    "        return actions\n",
    "    \n",
    "    def get_target_action(self, states, std_noise=0.2, clip_eta=0.5):\n",
    "        '''\n",
    "        Will be used to create target for critic optimization.\n",
    "        Returns actions with added \"clipped noise\".\n",
    "        input:\n",
    "            states - PyTorch tensor, (batch_size x features)\n",
    "        output:\n",
    "            actions - PyTorch tensor, (batch_size x actions_dim)\n",
    "        '''\n",
    "        # no gradient computation is required here since critic targets are treated as constants\n",
    "        with torch.no_grad():\n",
    "            actions = <YOUR CODE>\n",
    "            \n",
    "            # actions can fly out of [-1, 1] range after added noise\n",
    "            return actions.clamp(-1, 1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For **SAC**, model a *Gaussian policy*. This means the policy distribution is a multivariate normal with diagonal covariance. The policy head predicts the mean and the standard deviation $\\sigma$ of each action dimension, and $\\sigma$ must be guaranteed to be positive. **Important:** the way you parametrize $\\sigma$ strongly influences the optimization procedure, so here are some options. Let $f_{\\theta}$ be the output of the standard deviation head, then:\n",
    "* use the exponential function: $\\sigma(s) = \\exp(f_{\\theta}(s))$\n",
    "* transform the output to $[-1, 1]$ using `tanh`, project it linearly onto an interval $[m, M]$ with $m = -20$, $M = 2$, and then apply the exponential function. This keeps the modeled standard deviation in a reasonable range. The resulting formula is:\n",
    "$$\\sigma(s) = \\exp\\left(m + 0.5(M - m)(\\tanh(f_{\\theta}(s)) + 1)\\right)$$\n",
    "* the `softplus` operation $\\sigma(s) = \\log(1 + \\exp(f_{\\theta}(s)))$ seems to work poorly here. o_O\n",
    "\n",
    "**Note**: `torch.distributions.Normal` already has everything you need to work with such a policy once you have modeled the mean and standard deviation: sampling via the reparametrization trick (see the `rsample` method) and computing log probabilities (see the `log_prob` method). Keep in mind that `log_prob` returns one value per action dimension, so sum over the last dimension to get the log probability of the whole action.\n",
    "\n",
    "There is one more problem with the Gaussian distribution: we need to force our actions to lie in $[-1, 1]$. To achieve this, model an unbounded Gaussian $\\mathcal{N}(\\mu_{\\theta}(s), \\sigma_{\\theta}(s)^2I)$, where $\\mu$ can be arbitrary. Then, every time you draw a sample $u$ from this Gaussian, squash it with the $\\operatorname{tanh}$ function to get a sample in $[-1, 1]$:\n",
    "$$u \\sim \\mathcal{N}(\\mu, \\sigma^2I)$$\n",
    "$$a = \\operatorname{tanh}(u)$$\n",
    "\n",
    "**Important:** after that you have to use the change of variables formula every time you compute the likelihood (see appendix C in the [paper on SAC](https://arxiv.org/pdf/1801.01290.pdf) for details):\n",
    "$$\\log p(a \\mid \\mu, \\sigma) = \\log p(u \\mid \\mu, \\sigma) - \\sum_{i = 1}^D \\log (1 - \\operatorname{tanh}^2(u_i)),$$\n",
    "where $D$ is `action_dim`. In practice, add something like 1e-6 inside the logarithm to guard against numerical instabilities."
   ]
  },
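  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, here is a short sketch of where the change of variables formula comes from. The squashing $a_i = \\operatorname{tanh}(u_i)$ acts on each coordinate separately, so its Jacobian is diagonal with entries $\\partial a_i / \\partial u_i = 1 - \\operatorname{tanh}^2(u_i)$, and the density of $a$ is the density of $u$ divided by the Jacobian determinant:\n",
    "\n",
    "$$p(a \\mid \\mu, \\sigma) = p(u \\mid \\mu, \\sigma) \\prod_{i = 1}^D \\left(1 - \\operatorname{tanh}^2(u_i)\\right)^{-1} \\quad \\Longrightarrow \\quad \\log p(a \\mid \\mu, \\sigma) = \\log p(u \\mid \\mu, \\sigma) - \\sum_{i = 1}^D \\log (1 - \\operatorname{tanh}^2(u_i))$$\n",
    "\n",
    "As an optional alternative to adding a small constant inside the logarithm (this is not required by the assignment), the correction term can be rewritten exactly in a form that stays finite for large $|u_i|$, where $\\operatorname{softplus}(x) = \\log(1 + e^{x})$:\n",
    "\n",
    "$$\\log (1 - \\operatorname{tanh}^2(u_i)) = 2 \\left(\\log 2 - u_i - \\operatorname{softplus}(-2 u_i)\\right)$$\n",
    "\n",
    "This follows from $1 - \\operatorname{tanh}^2(u) = 4 / (e^{u} + e^{-u})^2$ and $\\log(e^{u} + e^{-u}) = u + \\log(1 + e^{-2u})$."
   ]
  },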
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-16T18:41:06.246418Z",
     "start_time": "2020-09-16T18:41:05.841255Z"
    }
   },
   "outputs": [],
   "source": [
    "# template for SAC\n",
    "from torch.distributions import Normal\n",
    "\n",
    "class SAC_Actor(nn.Module):\n",
    "    def __init__(self, state_dim, action_dim):\n",
    "        super().__init__()        \n",
    "\n",
    "        <YOUR CODE>\n",
    "    \n",
    "    def apply(self, states):\n",
    "        '''\n",
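    "        For a given batch of states, samples actions and also returns their log probabilities.\n",
    "        Note: this method shadows nn.Module.apply(fn), so do not call actor.apply(fn) for weight init.\n",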
    "        input:\n",
    "            states - PyTorch tensor, (batch_size x features)\n",
    "        output:\n",
    "            actions - PyTorch tensor, (batch_size x action_dim)\n",
    "            log_prob - PyTorch tensor, (batch_size)\n",
    "        '''\n",
    "        <YOUR CODE>      \n",
    "        \n",
    "        return actions, log_prob  \n",
    "\n",
    "    def get_action(self, states):\n",
    "        '''\n",
    "        Used to interact with environment by sampling actions from policy\n",
    "        input:\n",
    "            states - numpy, (batch_size x features)\n",
    "        output:\n",
    "            actions - numpy, (batch_size x actions_dim)\n",
    "        '''\n",
    "        # no gradient computation is required here since we will use this only for interaction\n",
    "        with torch.no_grad():\n",
    "            \n",
    "            # hint: you can use `apply` method here\n",
    "            actions = <YOUR CODE>\n",
    "            \n",
    "            assert isinstance(actions, (list,np.ndarray)), \"convert actions to numpy to send into env\"\n",
    "            assert actions.max() <= 1. and actions.min() >= -1, \"actions must be in the range [-1, 1]\"\n",
    "            return actions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## ReplayBuffer\n",
    "\n",
    "The same as in DQN. You can copy the code from your DQN assignment; just check that it works with continuous actions (it probably does).\n",
    "\n",
    "Let's recall the interface:\n",
    "* `exp_replay.add(obs, act, rw, next_obs, done)` - saves (s,a,r,s',done) tuple into the buffer\n",
    "* `exp_replay.sample(batch_size)` - returns observations, actions, rewards, next_observations and is_done for `batch_size` random samples.\n",
    "* `len(exp_replay)` - returns number of elements stored in replay buffer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ReplayBuffer():\n",
    "    def __init__(self, size):\n",
    "        \"\"\"\n",
    "        Create Replay buffer.\n",
    "        Parameters\n",
    "        ----------\n",
    "        size: int\n",
    "            Max number of transitions to store in the buffer. When the buffer\n",
    "            overflows the old memories are dropped.\n",
    "\n",
    "        Note: for this assignment you can pick any data structure you want.\n",
    "              If you want to keep it simple, you can store a list of tuples of (s, a, r, s', done) in self._storage\n",
    "              However you may find out there are faster and/or more memory-efficient ways to do so.\n",
    "        \"\"\"\n",
    "        self._storage = []\n",
    "        self._maxsize = size\n",
    "\n",
    "        # OPTIONAL: YOUR CODE\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self._storage)\n",
    "\n",
    "    def add(self, obs_t, action, reward, obs_tp1, done):\n",
    "        '''\n",
    "        Make sure _storage does not exceed _maxsize.\n",
    "        Make sure the FIFO rule is followed: the oldest examples have to be removed first.\n",
    "        '''        \n",
    "        data = (obs_t, action, reward, obs_tp1, done)\n",
    "        storage = self._storage\n",
    "        maxsize = self._maxsize\n",
    "        <YOUR CODE>\n",
    "            # add data to storage\n",
    "\n",
    "    def sample(self, batch_size):\n",
    "        \"\"\"Sample a batch of experiences.\n",
    "        Parameters\n",
    "        ----------\n",
    "        batch_size: int\n",
    "            How many transitions to sample.\n",
    "        Returns\n",
    "        -------\n",
    "        obs_batch: np.array\n",
    "            batch of observations\n",
    "        act_batch: np.array\n",
    "            batch of actions executed given obs_batch\n",
    "        rew_batch: np.array\n",
    "            rewards received as results of executing act_batch\n",
    "        next_obs_batch: np.array\n",
    "            next set of observations seen after executing act_batch\n",
    "        done_mask: np.array\n",
    "            done_mask[i] = 1 if executing act_batch[i] resulted in\n",
    "            the end of an episode and 0 otherwise.\n",
    "        \"\"\"\n",
    "        storage = self._storage\n",
    "        <YOUR CODE>\n",
    "            # randomly generate batch_size integers\n",
    "            # to be used as indexes of samples\n",
    "            \n",
    "        <YOUR CODE>\n",
    "            # collect <s,a,r,s',done> for each index\n",
    "        \n",
    "        return <YOUR CODE>\n",
    "            # <states>, <actions>, <rewards>, <next_states>, <is_done>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "exp_replay = ReplayBuffer(10)\n",
    "\n",
    "for _ in range(30):\n",
    "    exp_replay.add(env.reset()[0], env.action_space.sample(),\n",
    "                   1.0, env.reset()[0], done=False)\n",
    "\n",
    "obs_batch, act_batch, reward_batch, next_obs_batch, is_done_batch = exp_replay.sample(\n",
    "    5)\n",
    "\n",
    "assert len(exp_replay) == 10, \"experience replay size should be 10 because that's what maximum capacity is\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def play_and_record(initial_state, agent, env, exp_replay, n_steps=1):\n",
    "    \"\"\"\n",
    "    Play the game for exactly n steps, record every (s,a,r,s', done) to replay buffer. \n",
    "    Whenever game ends, add record with done=True and reset the game.\n",
    "    It is guaranteed that env has done=False when passed to this function.\n",
    "\n",
    "    :returns: return sum of rewards over time and the state in which the env stays\n",
    "    \"\"\"\n",
    "    s = initial_state\n",
    "    sum_rewards = 0\n",
    "\n",
    "    # Play the game for n_steps as per instructions above\n",
    "    for t in range(n_steps):\n",
    "        \n",
    "        # select action using policy with exploration\n",
    "        a = <YOUR CODE>       \n",
    "        \n",
    "        ns, r, terminated, truncated, _ = env.step(a)\n",
    "        \n",
    "        exp_replay.add(s, a, r, ns, terminated)\n",
    "        \n",
    "        s = env.reset()[0] if terminated or truncated else ns\n",
    "        \n",
    "        sum_rewards += r        \n",
    "\n",
    "    return sum_rewards, s"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#testing your code.\n",
    "exp_replay = ReplayBuffer(2000)\n",
    "actor = <YOUR ACTOR CLASS>(state_dim, action_dim).to(DEVICE)\n",
    "\n",
    "state, _ = env.reset()\n",
    "play_and_record(state, actor, env, exp_replay, n_steps=1000)\n",
    "\n",
    "# if you're using your own experience replay buffer, some of those tests may need correction.\n",
    "# just make sure you know what your code does\n",
    "assert len(exp_replay) == 1000, \"play_and_record should have added exactly 1000 steps, \"\\\n",
    "                                 \"but instead added %i\" % len(exp_replay)\n",
    "is_dones = list(zip(*exp_replay._storage))[-1]\n",
    "\n",
    "for _ in range(100):\n",
    "    obs_batch, act_batch, reward_batch, next_obs_batch, is_done_batch = exp_replay.sample(\n",
    "        10)\n",
    "    assert obs_batch.shape == next_obs_batch.shape == (10,) + (state_dim,)\n",
    "    assert act_batch.shape == (\n",
    "        10, action_dim), \"actions batch should have shape (10, 8) but is instead %s\" % str(act_batch.shape)\n",
    "    assert reward_batch.shape == (\n",
    "        10,), \"rewards batch should have shape (10,) but is instead %s\" % str(reward_batch.shape)\n",
    "    assert is_done_batch.shape == (\n",
    "        10,), \"is_done batch should have shape (10,) but is instead %s\" % str(is_done_batch.shape)\n",
    "    assert all(int(i) in (0, 1)\n",
    "               for i in is_dones), \"is_done should be strictly True or False\"\n",
    "\n",
    "print(\"Well done!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Initialization\n",
    "\n",
    "Let's start initializing our algorithm. Here are our hyperparameters:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gamma=0.99                    # discount factor\n",
    "max_buffer_size = 10**5       # size of experience replay\n",
    "start_timesteps = 5000        # number of transitions to collect before training starts\n",
    "timesteps_per_epoch=1         # environment steps per network update step\n",
    "batch_size=128                # batch size for all optimizations\n",
    "max_grad_norm=10              # max grad norm for all optimizations\n",
    "tau=0.005                     # speed of updating target networks\n",
    "policy_update_freq=<>         # frequency of actor update; vanilla choice is 2 for TD3 or 1 for SAC\n",
    "alpha=0.1                     # temperature for SAC\n",
    "\n",
    "# iterations passed\n",
    "n_iterations = 0"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here is our experience replay:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# experience replay\n",
    "exp_replay = ReplayBuffer(max_buffer_size)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here are our models: *two* critics and one actor."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# models to train\n",
    "actor = <YOUR ACTOR CLASS>(state_dim, action_dim).to(DEVICE)\n",
    "critic1 = Critic(state_dim, action_dim).to(DEVICE)\n",
    "critic2 = Critic(state_dim, action_dim).to(DEVICE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To stabilize training, we will need **target networks**: slowly updated copies of our models. In **TD3**, both critics and the actor have target copies; in **SAC**, only the critics have target copies, while the actor is always used fresh."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# target networks: slow-updated copies of actor and two critics\n",
    "target_critic1 = Critic(state_dim, action_dim).to(DEVICE)\n",
    "target_critic2 = Critic(state_dim, action_dim).to(DEVICE)\n",
    "target_actor = TD3_Actor(state_dim, action_dim).to(DEVICE)  # comment this line if you chose SAC\n",
    "\n",
    "# initialize them as copies of original models\n",
    "target_critic1.load_state_dict(critic1.state_dict())\n",
    "target_critic2.load_state_dict(critic2.state_dict())\n",
    "target_actor.load_state_dict(actor.state_dict())            # comment this line if you chose SAC "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In continuous control, target networks are usually updated using exponential smoothing:\n",
    "$$\\theta^{-} \\leftarrow \\tau \\theta + (1 - \\tau) \\theta^{-},$$\n",
    "where $\\theta^{-}$ are the target network weights, $\\theta$ are the fresh parameters, and $\\tau$ is a hyperparameter. This utility function does it:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def update_target_networks(model, target_model):\n",
    "    for param, target_param in zip(model.parameters(), target_model.parameters()):\n",
    "        target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)"
   ]
  },
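  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To get a feeling for what $\\tau$ does, the update can be unrolled over $k$ steps. This is a short worked calculation rather than part of the algorithm; the step indices $k$ and $j$ are introduced here only for illustration, with $\\theta_j$ the fresh parameters at step $j$ and $\\theta^{-}_0$ the initial target weights:\n",
    "\n",
    "$$\\theta^{-}_k = (1 - \\tau)^k \\theta^{-}_0 + \\tau \\sum_{j = 0}^{k - 1} (1 - \\tau)^{k - 1 - j} \\theta_j$$\n",
    "\n",
    "So the target network is an exponential moving average of past fresh parameters. With $\\tau = 0.005$ as set in the hyperparameters, $(1 - \\tau)^{200} = 0.995^{200} \\approx e^{-1} \\approx 0.37$, which means the target weights effectively average over roughly the last $1 / \\tau = 200$ updates."
   ]
  },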
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, we will have three optimization procedures to train our three models, so let's welcome our three Adams:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# optimizers: one per trained model\n",
    "opt_actor = torch.optim.Adam(actor.parameters(), lr=3e-4)\n",
    "opt_critic1 = torch.optim.Adam(critic1.parameters(), lr=3e-4)\n",
    "opt_critic2 = torch.optim.Adam(critic2.parameters(), lr=3e-4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# just to avoid writing this code three times\n",
    "def optimize(name, model, optimizer, loss):\n",
    "    '''\n",
    "    Makes one optimization step, clips the gradient norm to max_grad_norm and\n",
    "    logs everything into tensorboard\n",
    "    '''\n",
    "    loss = loss.mean()\n",
    "    optimizer.zero_grad()\n",
    "    loss.backward()\n",
    "    grad_norm = nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)\n",
    "    optimizer.step()\n",
    "\n",
    "    # logging\n",
    "    env.writer.add_scalar(name, loss.item(), n_iterations)\n",
    "    env.writer.add_scalar(name + \"_grad_norm\", grad_norm.item(), n_iterations)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Critic target computation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, let's discuss the losses for the critics and the actor.\n",
    "\n",
    "To train both critics, we minimize MSE against 1-step targets. For one sampled transition $(s, a, r, s')$ the target looks like this:\n",
    "$$y(s, a) = r + \\gamma V(s').$$\n",
    "\n",
    "How do we evaluate the next state and compute $V(s')$? Technically, a single-sample Monte Carlo estimate is simple:\n",
    "$$V(s') \\approx Q(s', a')$$\n",
    "where (important!) $a'$ is a sample from our current policy $\\pi(a' \\mid s')$.\n",
    "\n",
    "But our actor $\\pi$ is trained to search for actions $a'$ where the critic gives large estimates, so this straightforward approach leads to serious overestimation issues. We need some hacks. First, we will use target networks for $Q$ (and **TD3** also uses a target network for $\\pi$). Second, we will use *two* critics and take the minimum of their estimates:\n",
    "$$V(s') = \\min_{i = 1,2} Q^{-}_i(s', a'),$$\n",
    "where $a'$ is sampled from the target policy $\\pi^{-}(a' \\mid s')$ in **TD3** and from the fresh policy $\\pi(a' \\mid s')$ in **SAC**.\n",
    "\n",
    "###### Last but not least:\n",
    "* in **TD3**, compute $a'$ using the *clipped noise mode*, which prevents the policy from exploiting narrow peaks in the critic approximation;\n",
    "* in **SAC**, add (an estimate of) the entropy bonus in the next state $s'$:\n",
    "$$V(s') = \\min_{i = 1,2} Q^{-}_i(s', a') - \\alpha \\log \\pi (a' \\mid s')$$"
   ]
  },
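  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Putting the pieces together, the full 1-step targets can be sketched as follows. Here $d \\in \\{0, 1\\}$ denotes the `is_done` flag of the transition. Masking the bootstrap term with $(1 - d)$ is the usual convention for terminal transitions and is stated here as an assumption, since the formulas above leave it out. $c$ is the noise clipping bound and $\\epsilon$ is Gaussian noise.\n",
    "\n",
    "$$\\begin{aligned}\n",
    "\\textbf{TD3:} \\quad & y = r + \\gamma (1 - d) \\min_{i = 1,2} Q^{-}_i(s', a'), & a' &= \\operatorname{clip}\\left(\\pi^{-}(s') + \\operatorname{clip}(\\epsilon, -c, c), -1, 1\\right) \\\\\n",
    "\\textbf{SAC:} \\quad & y = r + \\gamma (1 - d) \\left( \\min_{i = 1,2} Q^{-}_i(s', a') - \\alpha \\log \\pi (a' \\mid s') \\right), & a' &\\sim \\pi(a' \\mid s')\n",
    "\\end{aligned}$$\n",
    "\n",
    "Both critics are regressed onto the same target $y$, which is treated as a constant (no gradients flow through it)."
   ]
  },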
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_critic_target(rewards, next_states, is_done):\n",
    "    '''\n",
    "    Important: use target networks for this method! Do not use \"fresh\" models except fresh policy in SAC!\n",
    "    input:\n",
    "        rewards - PyTorch tensor, (batch_size)\n",
    "        next_states - PyTorch tensor, (batch_size x features)\n",
    "        is_done - PyTorch tensor, (batch_size)\n",
    "    output:\n",
    "        critic target - PyTorch tensor, (batch_size)\n",
    "    '''\n",
    "    with torch.no_grad():\n",
    "        critic_target = <YOUR CODE>\n",
    "    \n",
    "    assert not critic_target.requires_grad, \"target must not require grad.\"\n",
    "    assert len(critic_target.shape) == 1, \"dangerous extra dimension in target?\"\n",
    "\n",
    "    return critic_target"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To train the actor, we simply want to maximize\n",
    "$$\\mathbb{E}_{a \\sim \\pi(a \\mid s)} Q(s, a) \\to \\max_{\\pi}$$\n",
    "\n",
    "* in **TD3**, because the policy is deterministic, the expectation reduces to:\n",
    "$$Q(s, \\pi(s)) \\to \\max_{\\pi}$$\n",
    "* in **SAC**, use the reparametrization trick to compute gradients, and do not forget to add the entropy regularizer, which encourages the policy to be as stochastic as possible:\n",
    "$$\\mathbb{E}_{a \\sim \\pi(a \\mid s)} \\left[ Q(s, a) - \\alpha \\log \\pi(a \\mid s) \\right] \\to \\max_{\\pi}$$\n",
    "\n",
    "**Note:** We will use the (fresh) critic1 here as the Q-function to \"exploit\". You can also use both critics and again take the minimum of their estimates (this is done in the original implementation of **SAC** but not in **TD3**), but this does not seem to matter much."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_actor_loss(states):\n",
    "    '''\n",
    "    Returns actor loss on batch of states\n",
    "    input:\n",
    "        states - PyTorch tensor, (batch_size x features)\n",
    "    output:\n",
    "        actor loss - PyTorch tensor, (batch_size)\n",
    "    '''\n",
    "    # make sure you have gradients w.r.t. actor parameters\n",
    "    actions = <YOUR CODE>\n",
    "    \n",
    "    assert actions.requires_grad, \"actions must be differentiable with respect to policy parameters\"\n",
    "    \n",
    "    # compute actor loss\n",
    "    actor_loss = <YOUR CODE>\n",
    "    return actor_loss"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Pipeline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally combining all together and launching our algorithm. Your goal is to reach at least 1000 average reward during evaluation after training in this ant environment (*since this is a new hometask, this threshold might be updated, so at least just see if your ant learned to walk in the rendered simulation*).\n",
    "\n",
    "* rewards should rise more or less steadily in this environment. There can be some drops due to instabilities of algorithm, but it should eventually start rising after 100K-200K iterations. If no progress in reward is observed after these first 100K-200K iterations, there is a bug.\n",
    "* gradient norm appears to be quite big for this task, it is ok if it reaches 100-200 (we handled it with clip_grad_norm). Consider everything exploded if it starts growing exponentially, then there is a bug."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "seed = <YOUR FAVOURITE RANDOM SEED>\n",
    "np.random.seed(seed)\n",
    "env.unwrapped.seed(seed)\n",
    "torch.manual_seed(seed);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm.notebook import trange\n",
    "\n",
    "interaction_state = env.reset()\n",
    "random_actor = RandomActor()\n",
    "\n",
    "for n_iterations in trange(0, 1000000, timesteps_per_epoch):\n",
    "    # if experience replay is small yet, no training happens\n",
    "    # we also collect data using random policy to collect more diverse starting data\n",
    "    if len(exp_replay) < start_timesteps:\n",
    "        _, interaction_state = play_and_record(interaction_state, random_actor, env, exp_replay, timesteps_per_epoch)\n",
    "        continue\n",
    "        \n",
    "    # perform a step in environment and store it in experience replay\n",
    "    _, interaction_state = play_and_record(interaction_state, actor, env, exp_replay, timesteps_per_epoch)\n",
    "        \n",
    "    # sample a batch from experience replay\n",
    "    states, actions, rewards, next_states, is_done = exp_replay.sample(batch_size)\n",
    "    \n",
    "    # move everything to PyTorch tensors\n",
    "    states = torch.tensor(states, device=DEVICE, dtype=torch.float)\n",
    "    actions = torch.tensor(actions, device=DEVICE, dtype=torch.float)\n",
    "    rewards = torch.tensor(rewards, device=DEVICE, dtype=torch.float)\n",
    "    next_states = torch.tensor(next_states, device=DEVICE, dtype=torch.float)\n",
    "    is_done = torch.tensor(\n",
    "        is_done.astype('float32'),\n",
    "        device=DEVICE,\n",
    "        dtype=torch.float\n",
    "    )\n",
    "    \n",
    "    # losses\n",
    "    critic1_loss = <YOUR CODE>\n",
    "    optimize(\"critic1\", critic1, opt_critic1, critic1_loss)\n",
    "\n",
    "    critic2_loss = <YOUR CODE>\n",
    "    optimize(\"critic2\", critic2, opt_critic2, critic2_loss)\n",
    "\n",
    "    # actor update is less frequent in TD3\n",
    "    if n_iterations % policy_update_freq == 0:\n",
    "        actor_loss = <YOUR CODE>\n",
    "        optimize(\"actor\", actor, opt_actor, actor_loss)\n",
    "\n",
    "        # update target networks\n",
    "        update_target_networks(critic1, target_critic1)\n",
    "        update_target_networks(critic2, target_critic2)\n",
    "        update_target_networks(actor, target_actor)                     # comment this line if you chose SAC"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-16T18:41:47.560269Z",
     "start_time": "2020-09-16T18:41:47.546277Z"
    }
   },
   "outputs": [],
   "source": [
    "def evaluate(env, actor, n_games=1, t_max=1000):\n",
    "    '''\n",
    "    Plays n_games and returns rewards and rendered games\n",
    "    '''\n",
    "    rewards = []\n",
    "\n",
    "    for _ in range(n_games):\n",
    "        s, _ = env.reset()\n",
    "\n",
    "        R = 0\n",
    "        for _ in range(t_max):\n",
    "            # select action for final evaluation of your policy\n",
    "            action = <YOUR CODE>\n",
    "\n",
    "            assert (action.max() <= 1).all() and  (action.min() >= -1).all()\n",
    "\n",
    "            s, r, terminated, truncated, _ = env.step(action)\n",
    "\n",
    "            R += r\n",
    "\n",
    "            if terminated or truncated:\n",
    "                break\n",
    "\n",
    "        rewards.append(R)\n",
    "    return np.array(rewards)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-16T18:38:45.130920Z",
     "start_time": "2020-09-16T18:38:13.090472Z"
    }
   },
   "outputs": [],
   "source": [
    "# evaluation will take some time!\n",
    "sessions = evaluate(env, actor, n_games=20)\n",
    "score = sessions.mean()\n",
    "print(f\"Your score: {score}\")\n",
    "\n",
    "assert score >= 1000, \"Needs more training?\"\n",
    "print(\"Well done!\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "env.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Record"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-16T18:43:19.559507Z",
     "start_time": "2020-09-16T18:43:19.522533Z"
    }
   },
   "outputs": [],
   "source": [
    "from gymnasium.wrappers import RecordVideo\n",
    "\n",
    "# let's hope this will work\n",
    "# don't forget to pray\n",
    "with gym.make(\"Ant-v4\", render_mode=\"rgb_array\") as env, RecordVideo(\n",
    "    env=env, video_folder=\"./videos\"\n",
    ") as env_monitor:\n",
    "    # note that t_max is 300, so collected reward will be smaller than 1000\n",
    "    evaluate(env_monitor, actor, n_games=1, t_max=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show video. This may not work in some setups. If it doesn't\n",
    "# work for you, you can download the videos and view them locally.\n",
    "\n",
    "from pathlib import Path\n",
    "from base64 import b64encode\n",
    "from IPython.display import HTML\n",
    "import sys\n",
    "\n",
    "video_paths = sorted([s for s in Path('videos').iterdir() if s.suffix == '.mp4'])\n",
    "video_path = video_paths[-1]  # You can also try other indices\n",
    "\n",
    "if 'google.colab' in sys.modules:\n",
    "    # https://stackoverflow.com/a/57378660/1214547\n",
    "    with video_path.open('rb') as fp:\n",
    "        mp4 = fp.read()\n",
    "    data_url = 'data:video/mp4;base64,' + b64encode(mp4).decode()\n",
    "else:\n",
    "    data_url = str(video_path)\n",
    "\n",
    "HTML(\"\"\"\n",
    "<video width=\"640\" height=\"480\" controls>\n",
    "  <source src=\"{}\" type=\"video/mp4\">\n",
    "</video>\n",
    "\"\"\".format(data_url))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Report\n",
    "\n",
    "We'd like to collect some statistics about computational resources you spent on this task. Please, report:\n",
    "* which GPU or CPU you used: <YOUR ANSWER>\n",
    "* number of iterations you used for training: <YOUR ANSWER>\n",
    "* wall-clock time spent (on computation =D): <YOUR ANSWER>"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
